{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "d658f909-e679-41e9-9c4e-e0241c719049",
   "metadata": {},
   "source": [
    "If you're not running in Saturn Cloud, you need to install these libraries:\n",
    "\n",
    "Make sure you use the latest versions\n",
    "\n",
    "```\n",
    "pip install -U transformers accelerate bitsandbytes\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "fd6ab5c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['HF_HOME'] = '/run/cache/'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "506fab2a-a50c-42bd-a106-c83a9d2828ea",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2024-06-13 19:36:42--  https://raw.githubusercontent.com/alexeygrigorev/minsearch/main/minsearch.py\n",
      "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.111.133, 185.199.108.133, ...\n",
      "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.\n",
      "HTTP request sent, awaiting response... 200 OK\n",
      "Length: 3832 (3.7K) [text/plain]\n",
      "Saving to: ‘minsearch.py’\n",
      "\n",
      "minsearch.py        100%[===================>]   3.74K  --.-KB/s    in 0s      \n",
      "\n",
      "2024-06-13 19:36:42 (50.3 MB/s) - ‘minsearch.py’ saved [3832/3832]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "!rm -f minsearch.py\n",
    "!wget https://raw.githubusercontent.com/alexeygrigorev/minsearch/main/minsearch.py"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "3ac947de-effd-4b61-8792-a6d7a133f347",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<minsearch.Index at 0x7f07b4322940>"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import requests \n",
    "import minsearch\n",
    "\n",
    "docs_url = 'https://github.com/DataTalksClub/llm-zoomcamp/blob/main/01-intro/documents.json?raw=1'\n",
    "docs_response = requests.get(docs_url)\n",
    "documents_raw = docs_response.json()\n",
    "\n",
    "documents = []\n",
    "\n",
    "for course in documents_raw:\n",
    "    course_name = course['course']\n",
    "\n",
    "    for doc in course['documents']:\n",
    "        doc['course'] = course_name\n",
    "        documents.append(doc)\n",
    "\n",
    "index = minsearch.Index(\n",
    "    text_fields=[\"question\", \"text\", \"section\"],\n",
    "    keyword_fields=[\"course\"]\n",
    ")\n",
    "\n",
    "index.fit(documents)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "8f087272-b44d-4738-9ea2-175ec63a058b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def search(query):\n",
    "    boost = {'question': 3.0, 'section': 0.5}\n",
    "\n",
    "    results = index.search(\n",
    "        query=query,\n",
    "        filter_dict={'course': 'data-engineering-zoomcamp'},\n",
    "        boost_dict=boost,\n",
    "        num_results=5\n",
    "    )\n",
    "\n",
    "    return results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "fe8bff3e-b672-42be-866b-f2d9bb217106",
   "metadata": {},
   "outputs": [],
   "source": [
    "def rag(query):\n",
    "    search_results = search(query)\n",
    "    prompt = build_prompt(query, search_results)\n",
    "    answer = llm(prompt)\n",
    "    return answer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "e19b3bc4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Filesystem      Size  Used Avail Use% Mounted on\r\n",
      "overlay         100G   36G   65G  36% /\r\n",
      "tmpfs            64M     0   64M   0% /dev\r\n",
      "tmpfs           7.7G     0  7.7G   0% /sys/fs/cgroup\r\n",
      "/dev/nvme0n1p1  100G   36G   65G  36% /run\r\n",
      "tmpfs            14G     0   14G   0% /dev/shm\r\n",
      "/dev/nvme2n1    2.0G  1.1G  858M  56% /home/jovyan\r\n",
      "tmpfs            14G  120K   14G   1% /home/jovyan/.saturn\r\n",
      "tmpfs            14G   12K   14G   1% /run/secrets/kubernetes.io/serviceaccount\r\n",
      "tmpfs           7.7G   12K  7.7G   1% /proc/driver/nvidia\r\n",
      "tmpfs           7.7G  3.6M  7.7G   1% /run/nvidia-persistenced/socket\r\n",
      "tmpfs           7.7G     0  7.7G   0% /proc/acpi\r\n",
      "tmpfs           7.7G     0  7.7G   0% /sys/firmware\r\n"
     ]
    }
   ],
   "source": [
    "!df -h"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "091a77e6-936b-448e-a04b-bad1001f5bb0",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import T5Tokenizer, T5ForConditionalGeneration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "8ef4e479",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ab8f600ab3bf4285af359b7ebecd03f6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "tokenizer_config.json:   0%|          | 0.00/2.54k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b7dc06cbc3074d13b38b4405dd06bce1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c265985987934138bc59c84d26448efe",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "special_tokens_map.json:   0%|          | 0.00/2.20k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8e9cdf126e99479e9c112e45907a4236",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "tokenizer.json:   0%|          | 0.00/2.42M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565\n",
      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ac29dc94c034480787ee4727ccfc91d7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "config.json:   0%|          | 0.00/1.44k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "597522b395a344d5a42cf91ff439dc53",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model.safetensors.index.json:   0%|          | 0.00/53.0k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1d1ccd16d9f143b4a3bc22d17a1c7417",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b7b18d9de2814a1cb711dc1451a769d3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model-00001-of-00002.safetensors:   0%|          | 0.00/9.45G [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "881907a9b94f4875a3fb919cfb12215d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model-00002-of-00002.safetensors:   0%|          | 0.00/1.95G [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "50a172f264d7437fa3bf585d4a9a2b06",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "10bb13f26e0543f69b0d87ac38442d52",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "tokenizer = T5Tokenizer.from_pretrained(\"google/flan-t5-xl\")\n",
    "model = T5ForConditionalGeneration.from_pretrained(\"google/flan-t5-xl\", device_map=\"auto\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "a5c8c36a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_prompt(query, search_results):\n",
    "    prompt_template = \"\"\"\n",
    "You're a course teaching assistant. Answer the QUESTION based on the CONTEXT from the FAQ database.\n",
    "Use only the facts from the CONTEXT when answering the QUESTION.\n",
    "\n",
    "QUESTION: {question}\n",
    "\n",
    "CONTEXT:\n",
    "{context}\n",
    "\"\"\".strip()\n",
    "\n",
    "    context = \"\"\n",
    "    \n",
    "    for doc in search_results:\n",
    "        context = context + f\"section: {doc['section']}\\nquestion: {doc['question']}\\nanswer: {doc['text']}\\n\\n\"\n",
    "    \n",
    "    prompt = prompt_template.format(question=query, context=context).strip()\n",
    "    return prompt\n",
    "\n",
    "def llm(prompt):\n",
    "    input_ids = tokenizer(prompt, return_tensors=\"pt\").input_ids.to(\"cuda\")\n",
    "    outputs = model.generate(input_ids, )\n",
    "    result = tokenizer.decode(outputs[0])\n",
    "    return result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "bcc3c46b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def llm(prompt, generate_params=None):\n",
    "    if generate_params is None:\n",
    "        generate_params = {}\n",
    "\n",
    "    input_ids = tokenizer(prompt, return_tensors=\"pt\").input_ids.to(\"cuda\")\n",
    "    outputs = model.generate(\n",
    "        input_ids,\n",
    "        max_length=generate_params.get(\"max_length\", 100),\n",
    "        num_beams=generate_params.get(\"num_beams\", 5),\n",
    "        do_sample=generate_params.get(\"do_sample\", False),\n",
    "        temperature=generate_params.get(\"temperature\", 1.0),\n",
    "        top_k=generate_params.get(\"top_k\", 50),\n",
    "        top_p=generate_params.get(\"top_p\", 0.95),\n",
    "    )\n",
    "    result = tokenizer.decode(outputs[0], skip_special_tokens=True)\n",
    "    return result\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "7dff26d5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\"Yes, even if you don't register, you're still eligible to submit the homeworks. Be aware, however, that there will be deadlines for turning in the final projects. So don't leave everything for the last minute.\""
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "rag(\"I just discovered the course. Can I still join it?\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb0d64f6",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "saturn (Python 3)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
